{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validate dataset\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb)\n",
    "\n",
    "This notebook checks that the dataset loading functions work as intended. We check the frame loading, the features, and the image preprocessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade wormpose"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If using Google Colab, please **restart the runtime** after installing the package (menu: Runtime > Restart runtime)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first download some utility functions to display images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://raw.githubusercontent.com/iteal/wormpose/main/examples/ipython_utils.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download sample data\n",
    "Download the sample data, or skip this step to use another dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, shutil\n",
    "\n",
    "# Start from an empty folder so that the clone below does not fail on an existing copy\n",
    "sample_data_root = 'wormpose_data'\n",
    "if os.path.exists(sample_data_root):\n",
    "    shutil.rmtree(sample_data_root)\n",
    "os.mkdir(sample_data_root)\n",
    "!git clone https://github.com/iteal/wormpose_data.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set inputs\n",
    "\n",
    "We load the sample_data dataset. To use another dataset, update \"dataset_loader\" and \"dataset_path\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wormpose.dataset.loader import load_dataset\n",
    "\n",
    "# Different loaders exist for different datasets. We use \"sample_data\" for the tutorial data;\n",
    "# replace with \"tierpsy\" for Tierpsy tracker data, or with your custom dataset loader name\n",
    "dataset_loader = 'sample_data'\n",
    "\n",
    "# Set the path to the dataset.\n",
    "# For Tierpsy tracker data this is the root path of a folder containing one subfolder per video\n",
    "dataset_path = \"wormpose_data/datasets/sample_data\"\n",
    "\n",
    "# Set if the worm is lighter than the background in the image\n",
    "# In the sample data, the worm is darker so we set this variable to False\n",
    "worm_is_lighter = False\n",
    "\n",
    "# This function loads the dataset\n",
    "# optional fields: there is an optional resize parameter to resize the images\n",
    "# also you can select specific videos from the dataset instead of loading them all\n",
    "dataset = load_dataset(dataset_loader, dataset_path, worm_is_lighter=worm_is_lighter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sample data contains only one video. For another dataset, update \"video_name\" to choose a specific video in the dataset.\n",
    "\n",
    "We then choose which frames to display. Update the variables \"start\", \"end\" and \"step\" to visualize a different frame range.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_names = dataset.video_names\n",
    "print(f\"There are {len(video_names)} video(s) in the dataset: \\n{video_names}\")\n",
    "\n",
    "if len(video_names) == 0:\n",
    "    raise ValueError(\"No video found in dataset, check the path or the loading functions.\")\n",
    "    \n",
    "video_name = video_names[0]\n",
    "print(f\"\\nWe now inspect one video: \\\"{video_name}\\\", change the value of video_name to inspect another video.\")\n",
    "\n",
    "MAX_FRAMES = 100\n",
    "with dataset.frames_dataset.open(video_name) as frames:\n",
    "    step = max(1, len(frames) // MAX_FRAMES)\n",
    "    start, end = 0, len(frames)   \n",
    "print(f\"\\nWe inspect the frame range [{start}:{end}:{step}], change the value of start, end or step to inspect another frame range.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check frames reader\n",
    "\n",
    "Run this cell to check that frame loading works as intended. It should display the raw frames from the dataset for the frame range defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipython_utils import ImagesViewer\n",
    "\n",
    "img_viewer = ImagesViewer()\n",
    "with dataset.frames_dataset.open(video_name) as frames:\n",
    "    for frame in frames[start:end:step]:\n",
    "        img_viewer.add_image(frame)\n",
    "        \n",
    "img_viewer.view_as_slider()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check features\n",
    "\n",
    "Run the following cells to check that the features are consistent.\n",
    "First, we look at the average worm length in each video and check that the values are all about the same. The algorithm will be more accurate if all worms in the dataset have similar properties."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "print(\"Listing worm lengths for all videos (pixels):\")\n",
    "\n",
    "# Use a separate loop variable so that the video_name selected above is not overwritten\n",
    "for name in video_names:\n",
    "    features = dataset.features_dataset[name]\n",
    "    worm_length = features.measurements['worm_length']\n",
    "    average_worm_length = np.nanmean(worm_length)\n",
    "    print(f\"{name}: {average_worm_length:.1f}\")\n",
    "    \n",
    "print(f\"\\nThe global image size is set to : {dataset.image_shape} pixels. \\nWe will crop real images to this size and generate synthetic images of this size.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the next cell to check if the skeleton and worm width are accurate.\n",
    "\n",
    "The skeleton should be displayed on top of the worm body in gray. The head position should be displayed as a red dot. The worm width at three positions (head, midbody, tail) should be displayed as yellow circles with the radius corresponding to the width.\n",
    "\n",
    "This only displays frames where features are available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "from wormpose.images.worm_drawing import draw_skeleton, draw_measurements\n",
    "from ipython_utils import ImagesViewer\n",
    "\n",
    "\n",
    "def is_valid(skel, measurements):\n",
    "    return not np.any(np.isnan(skel)) and not np.any([np.isnan(x) for x in measurements[0]])\n",
    "\n",
    "img_viewer = ImagesViewer()\n",
    "VIEW_MAX = 100\n",
    "with dataset.frames_dataset.open(video_name) as frames:\n",
    "    features = dataset.features_dataset[video_name]\n",
    "    for frame, skel, measurements in zip(frames, features.skeletons, features.measurements):\n",
    "        if is_valid(skel, measurements):\n",
    "            colored_frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)\n",
    "            draw_skeleton(colored_frame, skel, color=(200, 200, 200), head_color=(0, 0, 255))\n",
    "            draw_measurements(colored_frame, skel, measurements, color=(0, 255, 255))    \n",
    "            img_viewer.add_image(colored_frame)\n",
    "            if img_viewer.count >= VIEW_MAX:\n",
    "                break\n",
    "        \n",
    "img_viewer.view_as_slider()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check image preprocessing\n",
    "\n",
    "We check if the image preprocessing is accurate.\n",
    "First we check that it can be pickled, which is necessary for multiprocessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "try:\n",
    "    pickle.dumps(dataset.frame_preprocessing)\n",
    "    print('frame_preprocessing test passed successfully')\n",
    "except Exception as e:\n",
    "    print(f'ERROR: frame_preprocessing is not picklable ({e}). This is needed for multiprocessing. Remove inner functions and classes from frame_preprocessing.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we run the preprocessing on actual frames. A yellow bounding box is drawn around the worm in the processed image. Use it to validate that all pixels not belonging to the worm object have been set to a uniform color."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "from wormpose.dataset.image_processing import frame_preprocessor\n",
    "from ipython_utils import ImagesViewer, display_as_slider\n",
    "\n",
    "orig_img_viewer, processed_img_viewer = ImagesViewer(), ImagesViewer()\n",
    "\n",
    "with dataset.frames_dataset.open(video_name) as frames:\n",
    "    for index, frame in enumerate(frames[start:end:step]): \n",
    "        processed_frame, _, worm_roi = frame_preprocessor.run(dataset.frame_preprocessing, frame)\n",
    "        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)\n",
    "        processed_frame = cv2.cvtColor(processed_frame, cv2.COLOR_GRAY2BGR)\n",
    "        cv2.rectangle(processed_frame, \n",
    "                      (worm_roi[1].start, worm_roi[0].start),  \n",
    "                      (worm_roi[1].stop, worm_roi[0].stop),\n",
    "                      color=(0, 255, 255))\n",
    "        orig_img_viewer.add_image(frame)\n",
    "        processed_img_viewer.add_image(processed_frame)\n",
    "\n",
    "display_as_slider(orig_img_viewer, processed_img_viewer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## All good?\n",
    "\n",
    "If every check looks OK, you can proceed with using the dataset. The notebook tutorial_sample_data goes through the training and prediction process."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
